{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize dependencies and get data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz\n",
      "11444224/11490434 [============================>.] - ETA: 0s"
     ]
    }
   ],
   "source": [
    "# Make print() a function even when this runs on a Python 2 kernel (the\n",
    "# notebook metadata declares python2), so the multi-argument print calls in\n",
    "# later cells emit plain text instead of tuples. No effect on Python 3.\n",
    "from __future__ import print_function\n",
    "\n",
    "import keras\n",
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "from keras import backend as K\n",
    "\n",
    "# Training configuration.\n",
    "# NOTE(review): `epochs` is defined here, but the model.fit call further\n",
    "# down passes a literal epochs=2 -- confirm which value is intended.\n",
    "batch_size = 128\n",
    "num_classes = 10\n",
    "epochs = 12\n",
    "\n",
    "# input image dimensions\n",
    "img_rows, img_cols = 28, 28\n",
    "\n",
    "# the data, shuffled and split between train and test sets\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Add an explicit single-channel axis, placed wherever the active Keras\n",
    "# image data format puts it (first for 'channels_first', last otherwise).\n",
    "if K.image_data_format() == 'channels_first':\n",
    "    input_shape = (1, img_rows, img_cols)\n",
    "else:\n",
    "    input_shape = (img_rows, img_cols, 1)\n",
    "\n",
    "# Reshape both splits to (n_samples,) + input_shape.\n",
    "x_train = x_train.reshape((x_train.shape[0],) + input_shape)\n",
    "x_test = x_test.reshape((x_test.shape[0],) + input_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pre-process image data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train shape: (60000, 28, 28, 1)\n",
      "60000 train samples\n",
      "10000 test samples\n"
     ]
    }
   ],
   "source": [
    "# Convert the images to float32 and divide by 255 (maps 8-bit pixel values\n",
    "# into [0, 1]; assumes the arrays arrive as integer data from the load cell).\n",
    "# Guarded on dtype so that re-running this cell does not divide data that\n",
    "# has already been scaled by 255 a second time.\n",
    "if x_train.dtype != 'float32':\n",
    "    x_train = x_train.astype('float32') / 255\n",
    "if x_test.dtype != 'float32':\n",
    "    x_test = x_test.astype('float32') / 255\n",
    "\n",
    "# Single-argument prints so the output is the same on Python 2 and 3.\n",
    "print('x_train shape: {!s}'.format(x_train.shape))\n",
    "print('{!s} train samples'.format(x_train.shape[0]))\n",
    "print('{!s} test samples'.format(x_test.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# convert class vectors to binary class matrices\n",
    "# Only encode while the labels are still a 1-D vector, so that re-running\n",
    "# this cell does not one-hot encode an already encoded matrix again.\n",
    "if y_train.ndim == 1:\n",
    "    y_train = keras.utils.to_categorical(y_train, num_classes)\n",
    "if y_test.ndim == 1:\n",
    "    y_test = keras.utils.to_categorical(y_test, num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a CNN-based deep neural network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Layer stack, in order: two 3x3 ReLU convolutions (32 then 64 filters),\n",
    "# 2x2 max-pooling, dropout, flatten, a 128-unit ReLU dense layer, dropout,\n",
    "# and a softmax output with one unit per class.\n",
    "model = Sequential([\n",
    "    Conv2D(32, kernel_size=(3, 3), activation='relu',\n",
    "           input_shape=input_shape),\n",
    "    Conv2D(64, (3, 3), activation='relu'),\n",
    "    MaxPooling2D(pool_size=(2, 2)),\n",
    "    Dropout(0.25),\n",
    "    Flatten(),\n",
    "    Dense(128, activation='relu'),\n",
    "    Dropout(0.5),\n",
    "    Dense(num_classes, activation='softmax'),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize the network architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"719pt\" viewBox=\"0.00 0.00 392.00 719.00\" width=\"392pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 715)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-715 388,-715 388,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 1648868526512 -->\n",
       "<g class=\"node\" id=\"node1\"><title>1648868526512</title>\n",
       "<polygon fill=\"none\" points=\"20,-664.5 20,-710.5 364,-710.5 364,-664.5 20,-664.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"108\" y=\"-683.8\">conv2d_1_input: InputLayer</text>\n",
       "<polyline fill=\"none\" points=\"196,-664.5 196,-710.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"224\" y=\"-695.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"196,-687.5 252,-687.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"224\" y=\"-672.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"252,-664.5 252,-710.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"308\" y=\"-695.3\">(None, 28, 28, 1)</text>\n",
       "<polyline fill=\"none\" points=\"252,-687.5 364,-687.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"308\" y=\"-672.3\">(None, 28, 28, 1)</text>\n",
       "</g>\n",
       "<!-- 1648868526176 -->\n",
       "<g class=\"node\" id=\"node2\"><title>1648868526176</title>\n",
       "<polygon fill=\"none\" points=\"41,-581.5 41,-627.5 343,-627.5 343,-581.5 41,-581.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"104.5\" y=\"-600.8\">conv2d_1: Conv2D</text>\n",
       "<polyline fill=\"none\" points=\"168,-581.5 168,-627.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"196\" y=\"-612.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"168,-604.5 224,-604.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"196\" y=\"-589.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"224,-581.5 224,-627.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"283.5\" y=\"-612.3\">(None, 28, 28, 1)</text>\n",
       "<polyline fill=\"none\" points=\"224,-604.5 343,-604.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"283.5\" y=\"-589.3\">(None, 26, 26, 32)</text>\n",
       "</g>\n",
       "<!-- 1648868526512&#45;&gt;1648868526176 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>1648868526512-&gt;1648868526176</title>\n",
       "<path d=\"M192,-664.366C192,-656.152 192,-646.658 192,-637.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-637.607 192,-627.607 188.5,-637.607 195.5,-637.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648868702080 -->\n",
       "<g class=\"node\" id=\"node3\"><title>1648868702080</title>\n",
       "<polygon fill=\"none\" points=\"41,-498.5 41,-544.5 343,-544.5 343,-498.5 41,-498.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"104.5\" y=\"-517.8\">conv2d_2: Conv2D</text>\n",
       "<polyline fill=\"none\" points=\"168,-498.5 168,-544.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"196\" y=\"-529.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"168,-521.5 224,-521.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"196\" y=\"-506.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"224,-498.5 224,-544.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"283.5\" y=\"-529.3\">(None, 26, 26, 32)</text>\n",
       "<polyline fill=\"none\" points=\"224,-521.5 343,-521.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"283.5\" y=\"-506.3\">(None, 24, 24, 64)</text>\n",
       "</g>\n",
       "<!-- 1648868526176&#45;&gt;1648868702080 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>1648868526176-&gt;1648868702080</title>\n",
       "<path d=\"M192,-581.366C192,-573.152 192,-563.658 192,-554.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-554.607 192,-544.607 188.5,-554.607 195.5,-554.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648868701968 -->\n",
       "<g class=\"node\" id=\"node4\"><title>1648868701968</title>\n",
       "<polygon fill=\"none\" points=\"0,-415.5 0,-461.5 384,-461.5 384,-415.5 0,-415.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"104.5\" y=\"-434.8\">max_pooling2d_1: MaxPooling2D</text>\n",
       "<polyline fill=\"none\" points=\"209,-415.5 209,-461.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237\" y=\"-446.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"209,-438.5 265,-438.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237\" y=\"-423.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"265,-415.5 265,-461.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"324.5\" y=\"-446.3\">(None, 24, 24, 64)</text>\n",
       "<polyline fill=\"none\" points=\"265,-438.5 384,-438.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"324.5\" y=\"-423.3\">(None, 12, 12, 64)</text>\n",
       "</g>\n",
       "<!-- 1648868702080&#45;&gt;1648868701968 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>1648868702080-&gt;1648868701968</title>\n",
       "<path d=\"M192,-498.366C192,-490.152 192,-480.658 192,-471.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-471.607 192,-461.607 188.5,-471.607 195.5,-471.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648868661288 -->\n",
       "<g class=\"node\" id=\"node5\"><title>1648868661288</title>\n",
       "<polygon fill=\"none\" points=\"39.5,-332.5 39.5,-378.5 344.5,-378.5 344.5,-332.5 39.5,-332.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"104.5\" y=\"-351.8\">dropout_1: Dropout</text>\n",
       "<polyline fill=\"none\" points=\"169.5,-332.5 169.5,-378.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"197.5\" y=\"-363.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"169.5,-355.5 225.5,-355.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"197.5\" y=\"-340.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"225.5,-332.5 225.5,-378.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"285\" y=\"-363.3\">(None, 12, 12, 64)</text>\n",
       "<polyline fill=\"none\" points=\"225.5,-355.5 344.5,-355.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"285\" y=\"-340.3\">(None, 12, 12, 64)</text>\n",
       "</g>\n",
       "<!-- 1648868701968&#45;&gt;1648868661288 -->\n",
       "<g class=\"edge\" id=\"edge4\"><title>1648868701968-&gt;1648868661288</title>\n",
       "<path d=\"M192,-415.366C192,-407.152 192,-397.658 192,-388.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-388.607 192,-378.607 188.5,-388.607 195.5,-388.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648849747640 -->\n",
       "<g class=\"node\" id=\"node6\"><title>1648849747640</title>\n",
       "<polygon fill=\"none\" points=\"50,-249.5 50,-295.5 334,-295.5 334,-249.5 50,-249.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"104.5\" y=\"-268.8\">flatten_1: Flatten</text>\n",
       "<polyline fill=\"none\" points=\"159,-249.5 159,-295.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"187\" y=\"-280.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"159,-272.5 215,-272.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"187\" y=\"-257.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"215,-249.5 215,-295.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"274.5\" y=\"-280.3\">(None, 12, 12, 64)</text>\n",
       "<polyline fill=\"none\" points=\"215,-272.5 334,-272.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"274.5\" y=\"-257.3\">(None, 9216)</text>\n",
       "</g>\n",
       "<!-- 1648868661288&#45;&gt;1648849747640 -->\n",
       "<g class=\"edge\" id=\"edge5\"><title>1648868661288-&gt;1648849747640</title>\n",
       "<path d=\"M192,-332.366C192,-324.152 192,-314.658 192,-305.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-305.607 192,-295.607 188.5,-305.607 195.5,-305.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648849891904 -->\n",
       "<g class=\"node\" id=\"node7\"><title>1648849891904</title>\n",
       "<polygon fill=\"none\" points=\"67,-166.5 67,-212.5 317,-212.5 317,-166.5 67,-166.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"119\" y=\"-185.8\">dense_1: Dense</text>\n",
       "<polyline fill=\"none\" points=\"171,-166.5 171,-212.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"199\" y=\"-197.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"171,-189.5 227,-189.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"199\" y=\"-174.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"227,-166.5 227,-212.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272\" y=\"-197.3\">(None, 9216)</text>\n",
       "<polyline fill=\"none\" points=\"227,-189.5 317,-189.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272\" y=\"-174.3\">(None, 128)</text>\n",
       "</g>\n",
       "<!-- 1648849747640&#45;&gt;1648849891904 -->\n",
       "<g class=\"edge\" id=\"edge6\"><title>1648849747640-&gt;1648849891904</title>\n",
       "<path d=\"M192,-249.366C192,-241.152 192,-231.658 192,-222.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-222.607 192,-212.607 188.5,-222.607 195.5,-222.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648849845104 -->\n",
       "<g class=\"node\" id=\"node8\"><title>1648849845104</title>\n",
       "<polygon fill=\"none\" points=\"57.5,-83.5 57.5,-129.5 326.5,-129.5 326.5,-83.5 57.5,-83.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"122.5\" y=\"-102.8\">dropout_2: Dropout</text>\n",
       "<polyline fill=\"none\" points=\"187.5,-83.5 187.5,-129.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"215.5\" y=\"-114.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"187.5,-106.5 243.5,-106.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"215.5\" y=\"-91.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"243.5,-83.5 243.5,-129.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"285\" y=\"-114.3\">(None, 128)</text>\n",
       "<polyline fill=\"none\" points=\"243.5,-106.5 326.5,-106.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"285\" y=\"-91.3\">(None, 128)</text>\n",
       "</g>\n",
       "<!-- 1648849891904&#45;&gt;1648849845104 -->\n",
       "<g class=\"edge\" id=\"edge7\"><title>1648849891904-&gt;1648849845104</title>\n",
       "<path d=\"M192,-166.366C192,-158.152 192,-148.658 192,-139.725\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-139.607 192,-129.607 188.5,-139.607 195.5,-139.607\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 1648849975784 -->\n",
       "<g class=\"node\" id=\"node9\"><title>1648849975784</title>\n",
       "<polygon fill=\"none\" points=\"70.5,-0.5 70.5,-46.5 313.5,-46.5 313.5,-0.5 70.5,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"122.5\" y=\"-19.8\">dense_2: Dense</text>\n",
       "<polyline fill=\"none\" points=\"174.5,-0.5 174.5,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"202.5\" y=\"-31.3\">input:</text>\n",
       "<polyline fill=\"none\" points=\"174.5,-23.5 230.5,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"202.5\" y=\"-8.3\">output:</text>\n",
       "<polyline fill=\"none\" points=\"230.5,-0.5 230.5,-46.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272\" y=\"-31.3\">(None, 128)</text>\n",
       "<polyline fill=\"none\" points=\"230.5,-23.5 313.5,-23.5 \" stroke=\"black\"/>\n",
       "<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"272\" y=\"-8.3\">(None, 10)</text>\n",
       "</g>\n",
       "<!-- 1648849845104&#45;&gt;1648849975784 -->\n",
       "<g class=\"edge\" id=\"edge8\"><title>1648849845104-&gt;1648849975784</title>\n",
       "<path d=\"M192,-83.3664C192,-75.1516 192,-65.6579 192,-56.7252\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"195.5,-56.6068 192,-46.6068 188.5,-56.6069 195.5,-56.6068\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import SVG\n",
    "from keras.utils.vis_utils import model_to_dot\n",
    "\n",
    "# Build the dot graph of the model (top-to-bottom layout, with layer names\n",
    "# and input/output shapes), then render it to SVG for inline display.\n",
    "architecture_graph = model_to_dot(model,\n",
    "                                  show_shapes=True,\n",
    "                                  show_layer_names=True,\n",
    "                                  rankdir='TB')\n",
    "SVG(architecture_graph.create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compile and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Configure training: Adadelta optimizer with its default settings,\n",
    "# categorical cross-entropy loss, and accuracy reported as a metric.\n",
    "model.compile(\n",
    "    optimizer=keras.optimizers.Adadelta(),\n",
    "    loss=keras.losses.categorical_crossentropy,\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/2\n",
      "60000/60000 [==============================] - 230s - loss: 0.1867 - acc: 0.9444 - val_loss: 0.0692 - val_acc: 0.9786\n",
      "Epoch 2/2\n",
      "60000/60000 [==============================] - 232s - loss: 0.1061 - acc: 0.9682 - val_loss: 0.0524 - val_acc: 0.9835\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x17fe7308240>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train on the training split; the test split is passed as validation data,\n",
    "# so per-epoch validation metrics are computed on it.\n",
    "# NOTE(review): epochs is the literal 2 here, not the `epochs` value set in\n",
    "# the first code cell -- confirm which one is intended.\n",
    "model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    batch_size=batch_size,\n",
    "    epochs=2,\n",
    "    verbose=1,\n",
    "    validation_data=(x_test, y_test),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict and test model performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000/10000 [==============================] - 13s    \n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the test split; the next cell reads score[0] as the loss and\n",
    "# score[1] as the accuracy.\n",
    "score = model.evaluate(x=x_test, y=y_test, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss: 0.0523672130647\n",
      "Test accuracy: 0.9835\n"
     ]
    }
   ],
   "source": [
    "# Single-argument prints: the text is identical on Python 2 and Python 3\n",
    "# ({!s} applies str(), the same conversion print() itself uses).\n",
    "print('Test loss: {!s}'.format(score[0]))\n",
    "print('Test accuracy: {!s}'.format(score[1]))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
